
import jax
from .jax_types import Params
from functools import partial

@partial(jax.jit, static_argnames=("ema_decay",))
def jit_ema_update(src_params: Params, tar_params: Params, ema_decay: float):
    """Blend two parameter pytrees with an exponential-moving-average update.

    Each pair of corresponding leaves ``(s, t)`` is combined as
    ``s * (1 - ema_decay) + t * ema_decay``. ``ema_decay`` is therefore the
    weight kept on the existing ``tar_params`` values.

    Args:
        src_params: Pytree of incoming values (presumably the online / freshly
            trained parameters; confirm against callers).
        tar_params: Pytree holding the running average. It must match the
            structure of ``src_params``, as ``jax.tree_util.tree_map`` requires.
        ema_decay: Weight given to ``tar_params``. It is a static argument to
            ``jax.jit``, so it must be hashable, and every distinct value
            compiles a separate version of this function.

    Returns:
        A new pytree with the structure of ``src_params`` containing the
        blended leaves. The inputs are not modified.
    """
    # ema_decay is static, so this is plain Python arithmetic. The result is
    # baked into the compiled graph as a constant rather than computed
    # on-device in the parameters' (possibly lower) precision.
    source_weight = 1 - ema_decay

    def _blend(src_leaf, tar_leaf):
        # Same operand order as the reference formula: src term first, then tar term.
        return src_leaf * source_weight + tar_leaf * ema_decay

    return jax.tree_util.tree_map(_blend, src_params, tar_params)